from cProfile import label

from fileinput import filename
from genericpath import isdir, isfile
from operator import truediv
from turtle import update
from scipy.io import savemat, loadmat
import scipy.sparse as sparse

import numpy as np

from utils import train_model, test_model, get_needed_dirs
from data_utils import get_dataset, vstack, get_data_params

import matplotlib.pyplot as plt
import os
import argparse

import warnings
import time
import pickle
from existing_attacks import kkt_attacks, mta_attack, min_max_attack

# Silence every warning globally to keep the attack logs readable.
warnings.filterwarnings("ignore")
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist_17',help="options: 2d_toy, mnist_17, mnist_38, mnist_49, mnist_69, dogfish, cifar10_05, cifar10_14, cifar10_89, enron, filtered_enron, imdb, adult")
parser.add_argument('--model_type',default='svm',help="victim model type: 'svm' or 'lr' (logistic regression)")
parser.add_argument('--weight_decay',default=0.09, type=float, help='weight decay for regularizers')
# NOTE(review): this help text looks copy-pasted from a verbosity flag; confirm what
# --use_custom_svm actually toggles (presumably consumed by utils.train_model).
parser.add_argument('--use_custom_svm',action="store_true", help='print the detailed information')
parser.add_argument('--rand_seed',default=1234, type=int, help='seed for random number generator')
parser.add_argument('--epsilon',default=0.03, type=float, help='poisoning ratio')
parser.add_argument('--obtained_epsilon',default=0.1, type=float, help='poisoning ratio that we have results for, useful for continuously running the attack')
parser.add_argument('--print_every',default=50, type=int, help='print the attack statistics every n iterations')

parser.add_argument('--flip_y',default=0.01, type=float, help='fraction of label noise')
parser.add_argument('--class_sep',default=2.0, type=float, help='class separability')
parser.add_argument('--addi_search_space',default=2.0, type=float, help='additional search space for the attack, added to min,max value of clean points')
parser.add_argument('--n_samples',default=100000, type=int, help='number of total samples')
parser.add_argument('--n_features',default=1, type=int, help='number of features for the toy dataset')
parser.add_argument('--num_sampled_points',default=50, type=int, help='how many samples to generate for grid search')
parser.add_argument('--train_frac',default=0.5, type=float, help='fraction of train data in all samples')

# attack params for algorithmic influence maximization method
# NOTE: main() overrides --lr/--num_opt_steps/--num_restart for mnist_17 and dogfish.
parser.add_argument('--init_mode', default='current',help="options: random [randomly generate init point], current [randomly select from existing data]")
parser.add_argument('--num_restart',default=10, type=int, help='how many restarts for optimizing the lr influence maximization')
parser.add_argument('--lr',default=0.01, type=float, help='learning rate when performing gradient ascent')
parser.add_argument('--num_opt_steps',default=100, type=int, help='optimization steps for lr influence maximization')
parser.add_argument('--optimizer', default='adam',help="supported optimizer options: 'gd', 'adagrad', 'adam'")
parser.add_argument('--attack', default='mta',help="attack options: 'min_max', 'mta', 'kkt', 'all' ")
parser.add_argument('--approx_type', default='exact',help="hvp computation options: 'lissa', 'exact', 'identity' ")
parser.add_argument('--target_error',default=0.05, type=float, help='error rate for the target model')

parser.add_argument('--generate_target',action="store_true", help='generate the target models first!')
parser.add_argument('--low_poison_budget',action="store_true", help='low poison regime and show enumerated results!')
parser.add_argument('--use_train',action="store_true", help='use training data (instead of test data) in the attack process')
parser.add_argument('--original',action="store_true", help='we will leverage the original target model generation process, not improved one')
parser.add_argument('--grad_calib_mode',default="orig", help='grad_calibration options: orig,conf,norm')
parser.add_argument('--minimize',action="store_true", help='minimize influence w.r.t. flipped label')

parser.add_argument('--if_rep_num',default=1, type=int, help='repeat n times to give better estimate on the induced model weight')
parser.add_argument('--check_model_update_every',default=0, type=int, help='how frequently to check model update, 0 means no model update!')
parser.add_argument('--test_single',action="store_true", help='test single poison optimality!')
parser.add_argument('--load_mat_target',action="store_true", help='load the target model from the mat file, not needed anymore!')
parser.add_argument('--C',default=0.0, type=float, help='loss threshold')
parser.add_argument('--lr_min_C',default=0.21, type=float, help='minimum loss threshold for logistic regression')
parser.add_argument('--min_max_lr_matlab',default=0.03, type=float, help='learning rate for the original min-max attack')

parser.add_argument('--no_target',action="store_true", help='no target model is needed for max loss attack')
parser.add_argument('--check_transfer',action="store_true", help='check the transferability of different models')
parser.add_argument('--check_enron',action="store_true", help='check vulnerability of truncated enron!')
parser.add_argument('--ldc_compare',default=1.0,type=float, help='compare to ldc by using larger poisoning fraction')
parser.add_argument('--use_slab',action="store_true", help='use oracle slab defense')
parser.add_argument('--use_sphere',action="store_true", help='use sphere defense')
# main() rounds this to an int before passing it to get_data_params (defense radii).
parser.add_argument('--percentile', default=90, type=float, help='percentile used when computing the sphere/slab defense radii')

args = parser.parse_args()
print(args)

def main(args):
    """Run the poisoning-attack comparison for ``args.dataset``.

    For each epoch value (several for cifar10_89, a single ``0`` otherwise) this:
    loads the dataset, applies dataset-specific hyper-parameter defaults to ``args``
    in place, computes the defense parameters via ``get_data_params``, prints the
    clean model's performance, and calls ``compare_sota_attacks``.

    Args:
        args: parsed command-line namespace. ``epoch``, ``weight_decay``, ``lr``,
            ``num_opt_steps`` and ``num_restart`` may be overwritten in place.
    """
    if args.dataset == 'imdb':
        args.weight_decay = 0.01
    np.random.seed(args.rand_seed)
    # cifar10_89 is evaluated once per epoch value passed to get_dataset
    # (presumably different feature-extractor checkpoints -- TODO confirm).
    epochs = [-1, 1, 50, 90, 100, 120] if args.dataset == 'cifar10_89' else [0]
    for epoch in epochs:
        args.epoch = epoch
        X_train, Y_train, X_test, Y_test, x_lims = get_dataset(args, epoch)
        if args.ldc_compare != 1.0:
            x_min, x_max = x_lims
            x_lims = [args.ldc_compare * x_min, args.ldc_compare * x_max]

        # Fix the default params for each dataset so that it is easy to run with few
        # params. This must happen before any model is trained, so that the clean
        # baseline printed below uses the same weight decay as the attacks in
        # compare_sota_attacks (relevant for enron/filtered_enron with 'lr').
        if args.dataset == 'mnist_17':
            args.lr = 0.1
            args.num_opt_steps = 200
            args.num_restart = 10
        elif args.dataset == 'dogfish':
            args.lr = 0.1
            args.num_opt_steps = 3000
            args.num_restart = 10
        elif args.dataset in ['enron', 'filtered_enron'] and args.model_type == 'lr':
            args.weight_decay = 0.01

        # get the class data so as to better evaluate impact of defenses
        percentile = int(np.round(float(args.percentile)))
        class_map, centroids, centroid_vec, sphere_radii, slab_radii = get_data_params(
            X_train,
            Y_train,
            percentile=percentile)
        defense_pars = [class_map, centroids, centroid_vec, sphere_radii, slab_radii]

        print("--- Train/Test Data Size --- ")
        print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)
        print("Search Space: Min {}, Max {}".format(x_lims[0], x_lims[1]))

        print("--- Performance of Clean Models --- ")
        clean_model = train_model(X_train, Y_train, args)
        # only the verbose printout is wanted here; the returned metrics are discarded
        test_model(X_train, Y_train, X_train, Y_train,
                   X_test, Y_test, clean_model, args, verbose=True)
        clean_data = [X_train, Y_train, X_test, Y_test]

        num_iter = int(args.epsilon * X_train.shape[0])
        use_test = not args.use_train
        num_points_test = 3

        compare_sota_attacks(args, x_lims, num_iter, num_points_test=num_points_test,
                             clean_data=clean_data, use_test=use_test, defense_pars=defense_pars)

def _load_pickle(fname):
    """Return the object pickled in *fname*, closing the file afterwards.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on result/target
    files produced by this project, never on untrusted input.
    """
    with open(fname, "rb") as file_to_read:
        return pickle.load(file_to_read)


def _dump_pickle(obj, fname):
    """Pickle *obj* into *fname* (overwriting it), closing the file afterwards."""
    with open(fname, "wb") as file_to_write:
        pickle.dump(obj, file_to_write)


def _get_target_errors(args, use_train_or_test, tar_model_dir):
    """Return the target-model error rates to sweep over for ``args.dataset``.

    Args:
        args: parsed command-line namespace (``dataset`` and ``model_type`` are read).
        use_train_or_test: 'use_train' or 'use_test'; only dogfish depends on it.
        tar_model_dir: directory of the target models; cifar10_89 loads its error
            list from ``generated_errors.npy`` in this directory.

    Raises:
        ValueError: if ``args.dataset`` has no predefined list of target errors.
    """
    dataset = args.dataset
    if dataset == 'mnist_17':
        return [0.03,0.04,0.05,0.06,0.07,0.09,0.11,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5]
    if dataset == 'dogfish':
        if use_train_or_test == 'use_test':
            return [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8]
        return [0.03,0.05,0.07,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.8]
    if dataset == 'mnist_69':
        return [0.03,0.04,0.05,0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
    if dataset == 'mnist_38':
        return [0.04,0.05,0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
    if dataset == 'mnist_49':
        if args.model_type == 'lr':
            return [0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
        return [0.05,0.06,0.07,0.09,0.11,0.12,0.15,0.17,0.2,0.23,0.25,0.3,0.35,0.4,0.45,0.5,0.55]
    if dataset == 'cifar10_05':
        return [0.03,0.05,0.07,0.45,0.5,0.55]
    if dataset == 'cifar10_14':
        return [0.03,0.05,0.07,0.45, 0.55]
    if dataset in ['enron','filtered_enron']:
        return [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7]
    if dataset == 'imdb':
        return [0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7]
    if dataset == 'adult':
        target_errors = [0.23,0.25,0.3,0.33,0.35,0.4,0.45,0.5,0.55,0.6]
        if args.model_type == 'lr':
            target_errors.append(0.65)
        return target_errors
    if dataset == 'cifar10_89':
        # because cifar10_89 has many datasets, we directly load the target errors.
        return np.load('{}/generated_errors.npy'.format(tar_model_dir))
    raise ValueError("no predefined target errors for dataset '{}'; "
                     "use --load_mat_target or --no_target".format(dataset))


def compare_sota_attacks(args,x_lims,num_iter,num_points_test,clean_data,use_test=True,defense_pars=None):
    """Run the MTA, KKT and/or original min-max attacks against a sweep of target models.

    For every target error rate (see ``_get_target_errors``) the matching saved target
    model is loaded, and each attack selected by ``args.attack`` is run against it.
    Attack results are cached as pickles under the result directory returned by
    ``get_needed_dirs`` and are reloaded instead of recomputed when present. Across
    all target models, the result with the highest test error is kept and summarised.

    Args:
        args: parsed command-line namespace.
        x_lims: [min, max] bounds of the poison search space.
        num_iter: number of poisoning points for the MTA attack (overridden when
            ``args.check_enron`` is set).
        num_points_test: unused; kept for interface compatibility.
        clean_data: [X_train, Y_train, X_test, Y_test].
        use_test: unused; the train/test choice is read from ``args.use_train``.
        defense_pars: defense parameters forwarded to the attacks for the
            sphere/slab defenses.

    Raises:
        ValueError: on inconsistent flags (only one of --use_slab/--use_sphere, or the
            mta attack combined with --no_target), or when target errors are needed
            but ``args.dataset`` has no predefined list.
    """
    X_train, Y_train, X_test, Y_test = clean_data

    # the clean model the attacks start from
    curr_model = train_model(X_train, Y_train, args)

    use_train_or_test = 'use_train' if args.use_train else 'use_test'

    # get the relevant dirs
    result_dir, fig_dir_name, tar_model_dir, tar_model_dir_mat = get_needed_dirs(args, X_train)
    if args.ldc_compare != 1.0:
        result_dir = '{}/ldc_scale_{}'.format(result_dir, args.ldc_compare)
        os.makedirs(result_dir, exist_ok=True)

    if args.use_sphere or args.use_slab:
        # currently, only consider both defenses, can distinguish further if needed
        if not (args.use_slab and args.use_sphere):
            raise ValueError("--use_slab and --use_sphere must be enabled together")
        result_dir = '{}/oracle_defense'.format(result_dir)
        os.makedirs(result_dir, exist_ok=True)

    print("----- Conduct Epsilon-{:.4f} poison attack ----- ".format(args.epsilon))
    # mta, max-loss, min_max attacks require target models. This model object only
    # serves as a container: its coef_/intercept_ are overwritten with each target.
    tar_model = train_model(X_train, Y_train, args)

    if args.load_mat_target or args.no_target:
        target_errors = [-1.0]
    else:
        target_errors = _get_target_errors(args, use_train_or_test, tar_model_dir)

    # size of the truncated enron clean set examined with --check_enron
    enron_clean_size = 4137
    if args.check_enron:
        num_iter = int(0.3 * enron_clean_size)

    # Record the best MTA result (highest test error) over all target models tested,
    # together with the curr/target-model losses at that best result.
    best_mta_train_loss = -np.ones(num_iter)
    best_mta_train_error = -np.ones(num_iter)
    best_mta_test_loss = -np.ones(num_iter)
    best_mta_test_error = -np.ones(num_iter)
    best_mta_target_error_loc = np.zeros(num_iter)
    best_mta_curr_loss = np.zeros(num_iter)
    best_mta_tar_loss = np.zeros(num_iter)

    # we only record the kkt epsilons we are interested in
    kkt_epsilons = [args.epsilon]

    best_kkt_test_errors = {}
    best_kkt_test_losses = {}
    best_kkt_train_errors = {}
    best_kkt_train_losses = {}
    best_target_errors_for_kkt = {}

    best_min_max_test_errors = {}
    best_min_max_test_losses = {}
    best_min_max_train_errors = {}
    best_min_max_train_losses = {}
    best_target_errors_for_min_max = {}

    for eps in kkt_epsilons:
        best_kkt_test_errors[eps] = -1
        best_min_max_test_errors[eps] = -1

    if args.check_enron:
        max_error = -1e10

    mta_skip_flag = True  # set to False once a target model is actually attacked by mta
    for target_error in target_errors:
        if args.load_mat_target:
            # target models from a matlab file; deprecated and normally unused
            if not args.use_train:
                tar_model_fname = '{}/{}_thetas_with_bias_exact_decay_09_v3_prune.mat'.format(tar_model_dir_mat,args.dataset)
            else:
                tar_model_fname = '{}/{}_thetas_with_bias_exact_decay_09_use_train_v3_prune.mat'.format(tar_model_dir_mat,args.dataset)
        elif not args.original:
            tar_model_fname = '{}/improved_best_theta_whole_err-{}'.format(tar_model_dir,target_error)
        else:
            tar_model_fname = '{}/orig_best_theta_whole_err-{}'.format(tar_model_dir,target_error)

        if args.no_target:
            # the attack runs without a target model, so there is nothing to load
            thetas = [None]
            biases = [None]
        elif os.path.isfile(tar_model_fname):
            # load the target model related stuffs
            if args.load_mat_target:
                f = loadmat(tar_model_fname)
                num_thetas = f['thetas'].shape[0]
                num_features = f['thetas'][0][0].shape[0]
                thetas = np.zeros((num_thetas, num_features))
                biases = np.zeros(num_thetas)
                test_errors = np.zeros(num_thetas)
                for i in range(num_thetas):
                    thetas[i, :] = f['thetas'][i][0][:, 0]
                    biases[i] = f['biases'][i][0][0][0]
                    test_errors[i] = f['test_errors'][i][0]
            else:
                f = _load_pickle(tar_model_fname)
                tar_model_train_loss = f['best_train_loss']
                tar_model_train_err = f['best_train_error']
                tar_model_test_loss = f['best_test_loss']
                tar_model_test_err = f['best_test_error']
                print("---- Loaded Target Model Info----")
                print("Train Error: {:.5f}, Train Loss: {:.5f}, Test Error: {:.5f}, Test Loss: {:.5f}".format(tar_model_train_err,tar_model_train_loss,\
                    tar_model_test_err,tar_model_test_loss))
                thetas = [f['best_theta']]
                biases = [f['best_bias']]
        else:
            print("Target Model File {} does not exist!".format(tar_model_fname))
            continue

        tar_model_type = 'original' if (args.load_mat_target or args.original) else 'improved'

        for theta_idx, (theta, bias) in enumerate(zip(thetas, biases)):
            if args.no_target:
                tar_model = None
            else:
                print(X_train.shape, theta.shape)
                tar_model.coef_ = np.array([theta])
                if args.load_mat_target:
                    tar_model.intercept_ = np.array([bias])
                else:
                    tar_model.intercept_ = bias
                print("------- Ideal Target Model Performance (Sanity Check ------- ")
                test_model(X_train, Y_train, X_train, Y_train, X_test, Y_test,
                           tar_model, args, verbose=True)

            if args.attack in ['mta','all']:
                if args.no_target:
                    raise ValueError("the mta attack needs a target model; it cannot be used with --no_target")
                mta_skip_flag = False
                clean_set_size = enron_clean_size if args.check_enron else X_train.shape[0]
                if args.load_mat_target:
                    target_error = test_errors[theta_idx]
                    mta_result_dir = '{}/mta_results/mat_target'.format(result_dir)
                    result_fname = '{}/mta_{}_of_{}_{}_target_model_err-{:.5f}_{}_results'.format(mta_result_dir,num_iter,clean_set_size,tar_model_type,target_error,use_train_or_test)
                else:
                    mta_result_dir = '{}/mta_results'.format(result_dir)
                    result_fname = '{}/mta_{}_of_{}_{}_target_model_err-{}_{}_results'.format(mta_result_dir,num_iter,clean_set_size,tar_model_type,target_error,use_train_or_test)
                os.makedirs(mta_result_dir, exist_ok=True)

                print(result_fname)
                if os.path.isfile(result_fname):
                    print("directly loading results of mta of target error {:.5f}".format(target_error))
                    results_dict = _load_pickle(result_fname)
                    mta_results = results_dict['mta_results']
                    mta_curr_and_tar_losses = results_dict['mta_loss_on_curr_and_tar']
                else:
                    print("running the mta attack with {} model of target error {:.5f}".format(tar_model_type,target_error))
                    mta_results, mta_curr_and_tar_losses = mta_attack(args,curr_model,tar_model,x_lims,num_iter,clean_data,\
                        use_slab=args.use_slab,use_sphere=args.use_sphere,defense_pars=defense_pars)
                    _dump_pickle({'mta_results': mta_results,
                                  'mta_loss_on_curr_and_tar': mta_curr_and_tar_losses}, result_fname)

                # keep, per poisoning step, the result from the target model that
                # maximizes test error (exhaustive search over target models)
                tmp_mta_train_loss = mta_results[-4]
                tmp_mta_train_error = mta_results[-3]
                tmp_mta_test_loss = mta_results[-2]
                tmp_mta_test_error = mta_results[-1]

                update_ids = tmp_mta_test_error > best_mta_test_error
                best_mta_train_loss[update_ids] = tmp_mta_train_loss[update_ids]
                best_mta_train_error[update_ids] = tmp_mta_train_error[update_ids]
                best_mta_test_loss[update_ids] = tmp_mta_test_loss[update_ids]
                best_mta_test_error[update_ids] = tmp_mta_test_error[update_ids]
                best_mta_target_error_loc[update_ids] = target_error
                best_mta_curr_loss[update_ids] = mta_curr_and_tar_losses[0][update_ids]
                best_mta_tar_loss[update_ids] = mta_curr_and_tar_losses[1][update_ids]

                # check the performance of truncated enron dataset
                if args.check_enron:
                    poison_X = mta_results[0]
                    poison_Y = mta_results[1]
                    num_chosen = int(0.03 * X_train.shape[0])
                    chosen_poison_X = poison_X[:num_chosen]
                    chosen_poison_Y = poison_Y[:num_chosen]
                    print(chosen_poison_X.shape,chosen_poison_Y.shape,X_train.shape,Y_train.shape)
                    full_X_p = vstack(X_train,chosen_poison_X)
                    full_Y_p = vstack(Y_train,chosen_poison_Y)
                    # train poisoned model from scratch
                    p_model = train_model(full_X_p,full_Y_p,args)
                    tmp_acc = p_model.score(X_test,Y_test)
                    print("poisoned data shape",poison_X.shape,poison_Y.shape,\
                    chosen_poison_X.shape,chosen_poison_Y.shape)
                    print("clean and whole data shape",X_train.shape,Y_train.shape,\
                    full_X_p.shape,full_Y_p.shape,X_test.shape,Y_test.shape)
                    if 1-tmp_acc > max_error:
                        max_error = 1-tmp_acc
                        print("current max test error:",max_error)

            if args.attack in ['kkt','all']:
                for kkt_epsilon in kkt_epsilons:
                    # finer epsilon steps for very small poisoning budgets
                    epsilon_increment = 0.0005 if kkt_epsilon <= 0.005 else 0.005

                    kkt_num_iter = int(kkt_epsilon * X_train.shape[0])
                    if args.load_mat_target:
                        target_error = test_errors[theta_idx]
                        kkt_result_dir = '{}/kkt_results/mat_target'.format(result_dir)
                        # previously referenced the undefined name max_loss_result_dir
                        result_fname = '{}/kkt_{}_of_{}_{}_target_model_err-{:.5f}_{}_results'.format(kkt_result_dir,kkt_num_iter,X_train.shape[0],\
                            tar_model_type,target_error,use_train_or_test)
                    else:
                        kkt_result_dir = '{}/kkt_results'.format(result_dir)
                        result_fname = '{}/kkt_{}_of_{}_{}_target_model_err-{}_{}_results'.format(kkt_result_dir,\
                        kkt_num_iter,X_train.shape[0],tar_model_type,target_error,use_train_or_test)
                    os.makedirs(kkt_result_dir, exist_ok=True)
                    print(result_fname)
                    if os.path.isfile(result_fname):
                        print("directly loading results of KKT of target error {:.5f}, epsilon {}!".format(target_error,kkt_epsilon))
                        results_dict = _load_pickle(result_fname)
                        kkt_results = results_dict['kkt_results']
                        best_kkt_results = results_dict['best_kkt_results']
                    else:
                        print("running the KKT of target error {:.5f}, epsilon {}!".format(target_error,kkt_epsilon))
                        kkt_results, best_kkt_results = kkt_attacks(args,x_lims,kkt_epsilon,\
                        clean_data, tar_model, epsilon_increment=epsilon_increment, verbose=True)
                        _dump_pickle({'kkt_results': kkt_results,
                                      'best_kkt_results': best_kkt_results}, result_fname)

                    tmp_kkt_train_loss = best_kkt_results[-4]
                    tmp_kkt_train_error = best_kkt_results[-3]
                    tmp_kkt_test_loss = best_kkt_results[-2]
                    tmp_kkt_test_error = best_kkt_results[-1]

                    if tmp_kkt_test_error > best_kkt_test_errors[kkt_epsilon]:
                        best_kkt_test_errors[kkt_epsilon] = tmp_kkt_test_error
                        best_kkt_train_errors[kkt_epsilon] = tmp_kkt_train_error
                        best_kkt_test_losses[kkt_epsilon] = tmp_kkt_test_loss
                        best_kkt_train_losses[kkt_epsilon] = tmp_kkt_train_loss
                        best_target_errors_for_kkt[kkt_epsilon] = target_error

            if args.attack in ['min_max','all']:
                # The original Min-Max Attack
                num_sgd_steps = 1
                min_max_lr_matlab = args.min_max_lr_matlab

                C = args.C
                if args.model_type == 'lr':
                    # log loss cannot have values smaller than 0,
                    # and so 0.21 seems to be a magic number that works
                    C = max(C,args.lr_min_C)

                for _epsilon in kkt_epsilons:
                    if args.dataset in ['mnist_17','mnist_38','mnist_69','cifar10_05']:
                        burn_frac = max(0.33, 0.02/_epsilon-1)
                    elif args.dataset in ['dogfish', 'enron']:
                        burn_frac = max(1.0, 0.10/_epsilon-1)
                    else:
                        burn_frac = 0.0
                    _num_iter = int(_epsilon * X_train.shape[0])
                    print("Orig Min Max Attack with {} poisons of {}".format(_num_iter,X_train.shape[0]))
                    min_max_result_dir = '{}/min_max_results'.format(result_dir)
                    os.makedirs(min_max_result_dir, exist_ok=True)

                    result_fname = '{}/min_max_{}_of_{}_{}_target_model_err-{:.5f}_SGDSteps-{}_lr-{}_burnin-{}_C-{}_{}_results'.format(min_max_result_dir,_num_iter,X_train.shape[0],\
                        tar_model_type,target_error, num_sgd_steps,min_max_lr_matlab,burn_frac,C,use_train_or_test)
                    print(result_fname)

                    if os.path.isfile(result_fname):
                        print("directly loading results of original min-max attack of target error {:.5f}, C {}, epsilon {}!".format(target_error,\
                            C,_epsilon))
                        min_max_results = _load_pickle(result_fname)['min_max_results']
                    else:
                        print("running the original min-max attack of target error {:.5f}, C {}, epsilon {}!".format(target_error,C,_epsilon))
                        min_max_results = min_max_attack(_epsilon,args,x_lims,_num_iter,clean_data,\
                            num_sgd_steps=num_sgd_steps,min_max_lr_matlab=min_max_lr_matlab,burn_frac=burn_frac,\
                                tar_model=tar_model,C=C,use_slab=args.use_slab,use_sphere=args.use_sphere,defense_pars=defense_pars)
                        _dump_pickle({'min_max_results': min_max_results}, result_fname)

                    # the last entry of each metric history is the final result
                    min_max_train_loss = min_max_results[-4][-1]
                    min_max_train_error = min_max_results[-3][-1]
                    min_max_test_loss = min_max_results[-2][-1]
                    min_max_test_error = min_max_results[-1][-1]

                    if min_max_test_error > best_min_max_test_errors[_epsilon]:
                        best_min_max_test_errors[_epsilon] = min_max_test_error
                        best_min_max_train_errors[_epsilon] = min_max_train_error
                        best_min_max_test_losses[_epsilon] = min_max_test_loss
                        best_min_max_train_losses[_epsilon] = min_max_train_loss
                        best_target_errors_for_min_max[_epsilon] = target_error

    if args.attack in ['mta','all'] and not mta_skip_flag:
        print("--- mta Attack Summary -----")
        print("[Final] Train Loss: {}, Train error: {}, Test Loss: {}, Test error: {}".format(best_mta_train_loss[-1],best_mta_train_error[-1],\
            best_mta_test_loss[-1],best_mta_test_error[-1]))
        unique, counts = np.unique(best_mta_target_error_loc, return_counts=True)
        print("Best Error Locations Values and Counts of mta Attack")
        print(np.asarray((unique, counts)).T)

    if args.attack in ['kkt','all']:
        for kkt_epsilon in kkt_epsilons:
            print("---- kkt attack summary of Eps {} ----".format(kkt_epsilon))
            if kkt_epsilon not in best_target_errors_for_kkt:
                # no target model file was found, so the attack never ran
                print("No KKT results available (no target model was attacked).")
                continue
            print("[Final] Train Loss: {}, Train Error: {}, Test Loss: {}, Test Error: {}".format(best_kkt_train_losses[kkt_epsilon],\
            best_kkt_train_errors[kkt_epsilon], best_kkt_test_losses[kkt_epsilon],best_kkt_test_errors[kkt_epsilon]))
            print("Best Error Locations of KKT Attack:")
            print(best_target_errors_for_kkt[kkt_epsilon])

    if args.attack in ['min_max','all']:
        for kkt_epsilon in kkt_epsilons:
            print("---- orig min-max attack summary of Eps {} ----".format(kkt_epsilon))
            if kkt_epsilon not in best_target_errors_for_min_max:
                # no target model file was found, so the attack never ran
                print("No min-max results available (no target model was attacked).")
                continue
            print("[Final] Train Loss: {}, Train Error: {}, Test Loss: {}, Test Error: {}".format(best_min_max_train_losses[kkt_epsilon],\
            best_min_max_train_errors[kkt_epsilon], best_min_max_test_losses[kkt_epsilon],best_min_max_test_errors[kkt_epsilon]))
            print("Best Error Locations of Min-Max Attack:")
            print(best_target_errors_for_min_max[kkt_epsilon])

# Only launch the (long-running) attack comparison when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main(args)